
import torch
import torch.nn as nn
import os
import logging
from typing import Tuple
from transformers import AutoTokenizer

import sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.config import get_params

# Logger obtained without a name, i.e. the root logger.
logger = logging.getLogger()

# Parameters returned by src.config.get_params(), read once at import time.
params = get_params()
max_seq_length = params.max_seq_length

# A tokenizer is only loaded when the backbone name contains 'bert';
# otherwise auto_tokenizer stays None.
auto_tokenizer = (
    AutoTokenizer.from_pretrained(params.backbone)
    if 'bert' in params.backbone
    else None
)

# Label id written into padded label positions: CrossEntropyLoss's ignore_index.
pad_token_label_id = nn.CrossEntropyLoss().ignore_index

class Dataset(torch.utils.data.Dataset):
    '''
        Map-style dataset over paired inputs and labels.

        The arguments are stored as given in ``X`` (inputs) and ``y`` (labels).
        Indexing yields ``(index, X[index], y[index])``.
    '''

    def __init__(self, inputs, ys):
        self.X, self.y = inputs, ys

    def __len__(self):
        # The size is taken from the inputs only; ``y`` is not consulted.
        return len(self.X)

    def __getitem__(self, index):
        # The index itself is returned together with the sample and its label.
        sample, label = self.X[index], self.y[index]
        return index, sample, label

def combine_two_batch(X1, X2, y1, y2) -> Tuple[torch.Tensor, torch.Tensor]:
    '''
        Merge two batches along the batch dimension, padding when needed.

        Args:
            X1: (bs1, sq1) or (bs1, c, h, w);
            X2: (bs2, sq2) or (bs2, c, h, w);
            y1: (bs1,) or (bs1, sq1);
            y2: (bs2,) or (bs2, sq2);

        Returns:
            (combine_X, combine_y) with bs1 + bs2 rows each.

            - If X1 and X2 have the same per-sample shape, both pairs are
              simply concatenated with torch.cat and keep their dtype.
            - Otherwise the inputs are right-padded to the longest sequence in
              a 2-D long tensor on X1's device, filled with the module-level
              ``auto_tokenizer.pad_token_id``. 1-D labels are concatenated;
              2-D labels are padded the same way with ``pad_token_label_id``.

        Raises:
            ValueError: padding is required but y1 is neither 1-D nor 2-D.

        Note:
            The padding path reads ``auto_tokenizer.pad_token_id``, so it fails
            with AttributeError when ``auto_tokenizer`` is None. The padded
            buffers are 2-D, so this path is meant for (bs, sq) inputs.
    '''
    # Same per-sample shape: plain concatenation, no padding value needed.
    if X1.shape[1:] == X2.shape[1:]:
        return torch.cat((X1, X2), dim=0), torch.cat((y1, y2), dim=0)

    # Validate before doing any work, instead of silently returning None.
    label_rank = len(y1.shape)
    if label_rank not in (1, 2):
        raise ValueError(
            'combine_two_batch only supports 1-D or 2-D labels when padding '
            'is required, got y1 with %d dimensions' % label_rank
        )

    # Empty batches contribute no rows and must not influence the padded width.
    non_empty = [(X, y) for X, y in ((X1, y1), (X2, y2)) if X.shape[0] > 0]
    total_rows = sum(X.shape[0] for X, _ in non_empty)
    max_len = max((X.shape[1] for X, _ in non_empty), default=0)

    device = X1.device
    pad_per_token_labels = label_rank == 2

    combine_X = torch.full(
        (total_rows, max_len), auto_tokenizer.pad_token_id,
        dtype=torch.long, device=device,
    )
    if pad_per_token_labels:
        combine_y = torch.full(
            (total_rows, max_len), pad_token_label_id,
            dtype=torch.long, device=device,
        )
    else:
        combine_y = torch.cat((y1, y2))

    # Every row of a batch has the same length, so each batch is copied with
    # one slice assignment rather than row by row.
    row = 0
    for X, y in non_empty:
        n_rows, length = X.shape[0], X.shape[1]
        combine_X[row:row + n_rows, :length] = X
        if pad_per_token_labels:
            combine_y[row:row + n_rows, :length] = y
        row += n_rows

    return combine_X, combine_y

class Continual_Dataset:
    '''
        Placeholder for continual learning over multiple datasets.

        Nothing is implemented yet: construction takes no arguments and
        stores no state.
    '''

    def __init__(self):
        '''Create an instance; no attributes are set.'''

# Script entry point: currently a no-op, so running this file directly does
# nothing beyond executing the module-level statements.
if __name__ == "__main__":
    pass
